using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace SatelliteRpc.Client.SourceGenerator;

/// <summary>
///  The generator of SatelliteRpc client.
/// </summary>
/// <remarks>
///  For every interface annotated with <c>SatelliteRpc.Client.SatelliteRpcAttribute</c> the generator
///  emits an implementation class (leading <c>I</c> removed, <c>Client</c> appended) whose methods forward
///  to <c>ISatelliteRpcClient.InvokeAsync</c> with the path <c>"{serviceName}/{methodName}"</c>, where the
///  service name is the attribute's first constructor argument. It also emits a single
///  <c>RpcClientServiceCollectionExtensions.AddAutoGeneratedClients</c> method that registers the clients
///  as singletons. The named attribute arguments <c>GenerateClient</c> and <c>GenerateDependencyInjection</c>
///  (both default <c>true</c>) switch the two outputs on or off per interface.
///  Only methods returning <see cref="System.Threading.Tasks.Task"/> or
///  <see cref="System.Threading.Tasks.Task{TResult}"/> are implemented; other methods are skipped and logged.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public class SatelliteRpcClientGenerator : ISourceGenerator
{
    private const string RpcAttributeName = "SatelliteRpc.Client.SatelliteRpcAttribute";

    // Namespace of the generated DI registration class.
    private const string RegistrationNamespace = "SatelliteRpc.Client";

    private const string RegistrationHintName = "RpcClientServiceCollectionExtensions.g.cs";

    public void Initialize(GeneratorInitializationContext context)
    {
        context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
    }

    public void Execute(GeneratorExecutionContext context)
    {
        context.LogMessage("Client Generator started");

        if (context.SyntaxReceiver is not SyntaxReceiver receiver)
            return;

        // A partial interface produces one syntax node per annotated declaration; handle each symbol
        // only once so AddSource is never called twice with the same hint name (which throws).
        var visited = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        var registrations = new List<(string ServiceType, string ImplementationType)>();
        var anyInterfaceProcessed = false;

        foreach (var @interface in receiver.CandidateInterfaces)
        {
            var model = context.Compilation.GetSemanticModel(@interface.SyntaxTree);
            if (model.GetDeclaredSymbol(@interface) is not INamedTypeSymbol symbol || !visited.Add(symbol))
                continue;

            var rpcAttribute = symbol.GetAttributes()
                .FirstOrDefault(ad => ad.AttributeClass?.ToDisplayString() == RpcAttributeName);
            if (rpcAttribute is null)
                continue;

            var generateClient = GetNamedBoolean(rpcAttribute, "GenerateClient", defaultValue: true);
            var generateDependencyInjection = GetNamedBoolean(rpcAttribute, "GenerateDependencyInjection", defaultValue: true);

            // Skip if both are false
            if (!generateClient && !generateDependencyInjection)
                continue;

            anyInterfaceProcessed = true;

            var serviceName = GetServiceName(rpcAttribute);
            context.LogMessage(serviceName ?? "Null");

            var className = GetClientClassName(symbol.Name);
            var containingNamespace = symbol.ContainingNamespace;
            // Types in the global namespace must not be wrapped in (or qualified by) a namespace.
            var namespaceName = containingNamespace is { IsGlobalNamespace: false }
                ? containingNamespace.ToDisplayString()
                : null;

            if (generateClient)
            {
                var source = BuildClientSource(context, symbol, namespaceName, className, serviceName);
                // Qualify the hint name by namespace so equally named interfaces in different
                // namespaces do not collide.
                var hintName = namespaceName is null ? $"{className}.g.cs" : $"{namespaceName}.{className}.g.cs";
                context.AddSource(hintName, SourceText.From(source, Encoding.UTF8));
            }

            if (generateDependencyInjection)
            {
                // When GenerateClient is false the registration still points at "{Name}Client" in the
                // interface's namespace, which then has to be provided by hand.
                var implementationType = namespaceName is null
                    ? $"global::{className}"
                    : $"{containingNamespace.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{className}";
                registrations.Add((FullName(symbol), implementationType));
            }
        }

        context.LogMessage("Client Generator finished");

        // The registration class is emitted exactly once, after all interfaces were collected,
        // and only if at least one annotated interface was processed.
        if (!anyInterfaceProcessed)
            return;

        context.LogMessage("DI Generator started");

        context.AddSource(RegistrationHintName, SourceText.From(BuildRegistrationSource(registrations), Encoding.UTF8));

        context.LogMessage("DI Generator finished");
    }

    /// <summary>
    ///  Builds the client class that implements <paramref name="symbol"/> by forwarding every
    ///  supported method to <c>ISatelliteRpcClient.InvokeAsync</c>.
    /// </summary>
    /// <param name="context">Generator context, used for logging only.</param>
    /// <param name="symbol">The annotated interface.</param>
    /// <param name="namespaceName">Namespace to place the class in, or <c>null</c> for the global namespace.</param>
    /// <param name="className">Name of the generated class.</param>
    /// <param name="serviceName">Service segment of the RPC path; <c>null</c> yields an empty segment.</param>
    /// <returns>The complete C# source of the client file.</returns>
    private static string BuildClientSource(
        GeneratorExecutionContext context,
        INamedTypeSymbol symbol,
        string? namespaceName,
        string className,
        string? serviceName)
    {
        var builder = new StringBuilder();

        builder.AppendLine("// <auto-generated/>");
        builder.AppendLine("using SatelliteRpc.Client;");
        builder.AppendLine("using SatelliteRpc.Client.Transport;");
        builder.AppendLine("using ServerProto;");
        builder.AppendLine("using System.Threading.Tasks;");
        builder.AppendLine("using System.Threading;");
        builder.AppendLine();

        if (namespaceName is not null)
        {
            builder.AppendLine($"namespace {namespaceName}");
            builder.AppendLine("{");
        }

        // The interface is referenced fully qualified so nested interfaces and types from other
        // namespaces resolve without relying on using directives.
        builder.AppendLine($"    public class {className} : {FullName(symbol)}");
        builder.AppendLine("    {");
        builder.AppendLine("        private readonly ISatelliteRpcClient _client;");
        builder.AppendLine();
        builder.AppendLine($"        public {className}(ISatelliteRpcClient client)");
        builder.AppendLine("        {");
        builder.AppendLine("            _client = client;");
        builder.AppendLine("        }");

        foreach (var method in symbol.GetMembers().OfType<IMethodSymbol>())
        {
            // Property/event accessors, static members and default-implemented members need no
            // generated implementation.
            if (method.MethodKind != MethodKind.Ordinary || method.IsStatic || !method.IsAbstract)
                continue;

            if (!TryGetTaskResultType(method.ReturnType, out var resultType))
            {
                context.LogMessage($"Generator method: {method.Name} skipped, unsupported return type {method.ReturnType}");
                continue;
            }

            AppendMethod(builder, method, serviceName, resultType);

            context.LogMessage($"Generator method: {method.Name} success");
        }

        builder.AppendLine("    }");

        if (namespaceName is not null)
            builder.AppendLine("}");

        return builder.ToString();
    }

    /// <summary>
    ///  Appends one forwarding method implementation to <paramref name="builder"/>.
    /// </summary>
    /// <param name="builder">Target source builder.</param>
    /// <param name="method">The interface method being implemented.</param>
    /// <param name="serviceName">Service segment of the RPC path.</param>
    /// <param name="resultType">The <c>T</c> of <c>Task&lt;T&gt;</c>, or <c>null</c> for a plain <c>Task</c>.</param>
    private static void AppendMethod(StringBuilder builder, IMethodSymbol method, string? serviceName, ITypeSymbol? resultType)
    {
        // '@' keeps parameter names valid even when they are C# keywords.
        var parameters = string.Join(", ", method.Parameters.Select(p => $"{FullName(p.Type)} @{p.Name}"));

        var returnType = resultType is null
            ? "global::System.Threading.Tasks.Task"
            : $"global::System.Threading.Tasks.Task<{FullName(resultType)}>";

        // Explicit <TRequest, TResponse> type arguments are only emitted when the method returns a value
        // and has more than one parameter (first parameter = request); otherwise inference is relied upon.
        var invokeTarget = resultType is not null && method.Parameters.Length > 1
            ? $"_client.InvokeAsync<{FullName(method.Parameters[0].Type)}, {FullName(resultType)}>"
            : "_client.InvokeAsync";

        // Escape the path so a service name with quotes or backslashes still yields a valid literal.
        var path = Microsoft.CodeAnalysis.CSharp.SymbolDisplay.FormatLiteral($"{serviceName}/{method.Name}", true);

        // Building the argument list this way avoids a dangling comma for parameterless methods.
        var arguments = new[] { path }.Concat(method.Parameters.Select(p => "@" + p.Name));
        var call = $"await {invokeTarget}({string.Join(", ", arguments)})";

        builder.AppendLine();
        builder.AppendLine($"        public async {returnType} {method.Name}({parameters})");
        builder.AppendLine("        {");
        builder.AppendLine(resultType is null ? $"            {call};" : $"            return {call};");
        builder.AppendLine("        }");
    }

    /// <summary>
    ///  Builds the <c>RpcClientServiceCollectionExtensions</c> class that registers every collected
    ///  interface/implementation pair as a singleton.
    /// </summary>
    /// <param name="registrations">Fully qualified (service type, implementation type) pairs.</param>
    /// <returns>The complete C# source of the registration file.</returns>
    private static string BuildRegistrationSource(IEnumerable<(string ServiceType, string ImplementationType)> registrations)
    {
        var builder = new StringBuilder();

        builder.AppendLine("// <auto-generated/>");
        builder.AppendLine("using Microsoft.Extensions.DependencyInjection;");
        builder.AppendLine();
        builder.AppendLine($"namespace {RegistrationNamespace}");
        builder.AppendLine("{");
        builder.AppendLine("    public static class RpcClientServiceCollectionExtensions");
        builder.AppendLine("    {");
        builder.AppendLine("        public static IServiceCollection AddAutoGeneratedClients(this IServiceCollection services)");
        builder.AppendLine("        {");

        foreach (var (serviceType, implementationType) in registrations)
        {
            builder.AppendLine($"            services.AddSingleton<{serviceType}, {implementationType}>();");
        }

        builder.AppendLine("            return services;");
        builder.AppendLine("        }");
        builder.AppendLine("    }");
        builder.AppendLine("}");

        return builder.ToString();
    }

    /// <summary>
    ///  Recognises <c>System.Threading.Tasks.Task</c> and <c>System.Threading.Tasks.Task&lt;T&gt;</c>.
    /// </summary>
    /// <param name="returnType">The method's declared return type.</param>
    /// <param name="resultType">Receives <c>T</c> for <c>Task&lt;T&gt;</c>; <c>null</c> otherwise.</param>
    /// <returns><c>true</c> if the return type is one of the two supported task types.</returns>
    private static bool TryGetTaskResultType(ITypeSymbol returnType, out ITypeSymbol? resultType)
    {
        resultType = null;

        if (returnType is not INamedTypeSymbol named
            || named.ContainingType is not null
            || named.ContainingNamespace?.ToDisplayString() != "System.Threading.Tasks")
            return false;

        if (named.MetadataName == "Task")
            return true;

        if (named.MetadataName == "Task`1")
        {
            resultType = named.TypeArguments[0];
            return true;
        }

        return false;
    }

    /// <summary>
    ///  Reads a boolean named argument of <paramref name="attribute"/>, or returns
    ///  <paramref name="defaultValue"/> when it is absent or not a boolean.
    /// </summary>
    private static bool GetNamedBoolean(AttributeData attribute, string name, bool defaultValue)
    {
        foreach (var argument in attribute.NamedArguments)
        {
            if (argument.Key == name
                && argument.Value.Kind != TypedConstantKind.Array
                && argument.Value.Value is bool value)
                return value;
        }

        return defaultValue;
    }

    /// <summary>
    ///  Returns the attribute's first constructor argument when it is a string; otherwise <c>null</c>.
    /// </summary>
    private static string? GetServiceName(AttributeData attribute)
    {
        if (attribute.ConstructorArguments.IsEmpty)
            return null;

        var first = attribute.ConstructorArguments[0];

        // TypedConstant.Value throws for array-kind constants, so check the kind first.
        return first.Kind == TypedConstantKind.Array ? null : first.Value as string;
    }

    /// <summary>
    ///  Derives the client class name: a leading <c>I</c> is removed and <c>Client</c> appended
    ///  (<c>IGreeter</c> → <c>GreeterClient</c>). Names without the <c>I</c> prefix keep their first letter.
    /// </summary>
    private static string GetClientClassName(string interfaceName)
    {
        var baseName = interfaceName.Length > 0 && interfaceName[0] == 'I'
            ? interfaceName.Substring(1)
            : interfaceName;

        return baseName + "Client";
    }

    /// <summary>Formats <paramref name="type"/> as a <c>global::</c>-qualified C# type name.</summary>
    private static string FullName(ITypeSymbol type) =>
        type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

    /// <summary>
    ///  Collects interface declarations that carry at least one attribute list; the attribute
    ///  itself is checked semantically in <see cref="Execute"/>.
    /// </summary>
    private class SyntaxReceiver : ISyntaxReceiver
    {
        public List<InterfaceDeclarationSyntax> CandidateInterfaces { get; } = new();

        public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
        {
            // We only care about interface declarations
            if (syntaxNode is InterfaceDeclarationSyntax { AttributeLists.Count: > 0 } @interface)
            {
                CandidateInterfaces.Add(@interface);
            }
        }
    }
}
